import os
import re
import sys
import zipfile
from collections import UserDict

import numpy as np
import pandas as pd
import requests


def extract_data(data_dir):
    """Unzip the GEFCom2014 archive and save the 'extended' load data as energy.csv.

    Expects ``GEFCom2014.zip`` to be present in ``data_dir``. The archive is
    unpacked into ``<data_dir>/GEFCom2014``, the nested ``GEFCom2014-E_V2.zip``
    is unpacked into ``<data_dir>/GEFCom2014-E``, and the Excel workbook found
    there is reduced to the columns ``timestamp``, ``load`` and ``temp`` and
    written to ``<data_dir>/energy.csv``.

    Args:
        data_dir: directory holding ``GEFCom2014.zip``; all outputs are
            written below it.

    Raises:
        SystemExit: if ``GEFCom2014.zip`` is missing (the exit message
            contains download instructions).
    """
    gefcom_dir = os.path.join(data_dir, "GEFCom2014", "GEFCom2014 Data")

    gefcom_zipfile = os.path.join(data_dir, "GEFCom2014.zip")
    if not os.path.exists(gefcom_zipfile):
        sys.exit(
            "Download GEFCom2014.zip from https://mlftsfwp.blob.core.windows.net/mlftsfwp/GEFCom2014.zip and save it to the '{}' directory.".format(
                data_dir
            )
        )

    # unzip root directory; the context manager closes the archive even if
    # extraction fails part-way through
    with zipfile.ZipFile(gefcom_zipfile, "r") as zip_ref:
        zip_ref.extractall(os.path.join(data_dir, "GEFCom2014"))

    # extract the extended competition data
    with zipfile.ZipFile(
        os.path.join(gefcom_dir, "GEFCom2014-E_V2.zip"), "r"
    ) as zip_ref:
        zip_ref.extractall(os.path.join(data_dir, "GEFCom2014-E"))

    # load the data from Excel file
    data = pd.read_excel(
        os.path.join(data_dir, "GEFCom2014-E", "GEFCom2014-E.xlsx"), parse_dates=["Date"]
    )

    # create timestamp variable from Date and Hour; Hour is shifted down by one
    # so that Hour == 1 maps to an offset of zero hours from Date
    data["timestamp"] = data["Date"].add(pd.to_timedelta(data.Hour - 1, unit="h"))
    data = data[["timestamp", "load", "T"]]
    data = data.rename(columns={"T": "temp"})

    # remove time period with no load data
    data = data[data.timestamp >= "2012-01-01"]

    # save to csv
    data.to_csv(os.path.join(data_dir, "energy.csv"), index=False)


def download_file(url):
    """Stream the resource at ``url`` into a local file and return its name.

    The file name is the last ``/``-separated segment of ``url`` with
    everything from the first ``?`` onwards removed. Because the name is
    relative, the file is created in the current working directory.
    ``raise_for_status`` is called before the file is opened, so an HTTP
    error status raises without creating the file.
    """
    last_segment = url.split("/")[-1]
    local_filename = re.search("^[^?]+", last_segment).group()

    # stream=True defers fetching the body, which is then consumed in
    # 8192-byte chunks instead of being loaded into memory in one piece.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        with open(local_filename, "wb") as out_file:
            out_file.writelines(response.iter_content(chunk_size=8192))

    return local_filename


def load_data(data_dir):
    """Load the GEFCom 2014 energy load data.

    Reads ``energy.csv`` from ``data_dir`` and returns a dataframe indexed by
    an hourly range running from the smallest to the largest timestamp in the
    file. Hours that are absent from the file appear as rows of NaN. The
    ``timestamp`` column itself is not part of the returned dataframe.

    Args:
        data_dir: directory containing ``energy.csv``.

    Returns:
        The dataframe of the remaining CSV columns on an hourly index.
    """

    energy = pd.read_csv(
        os.path.join(data_dir, "energy.csv"), parse_dates=["timestamp"]
    )

    # Reindex the dataframe such that the dataframe has a record for every time point
    # between the minimum and maximum timestamp in the time series. This helps to
    # identify missing time periods in the data (there are none in this dataset).
    timestamps = energy["timestamp"]

    # "h" is the hourly alias; the uppercase "H" spelling is deprecated since
    # pandas 2.2.
    full_range = pd.date_range(timestamps.min(), timestamps.max(), freq="h")

    return energy.set_index("timestamp").reindex(full_range)


def mape(predictions, actuals):
    """Mean absolute percentage error.

    Computes ``mean(|predictions - actuals| / |actuals|)``. The result is a
    fraction (0.1 means 10%), not multiplied by 100.

    Args:
        predictions: forecast values (pandas Series/DataFrame or numpy array).
        actuals: observed values, same kind and shape as ``predictions``.

    Returns:
        The mean of the element-wise absolute percentage errors, as produced
        by the ``.mean()`` method of the inputs' type.
    """
    # Divide by the magnitude of the actuals so that negative actuals do not
    # turn an absolute error into a negative contribution. np.abs (rather than
    # the pandas-only .abs() method) also lets plain numpy arrays through.
    return (np.abs(predictions - actuals) / np.abs(actuals)).mean()


def create_evaluation_df(predictions, test_inputs, H, scaler):
    """Build a long-format dataframe pairing each prediction with its actual.

    ``predictions`` is laid out with one column per horizon step (``t+1`` ..
    ``t+H``) and then melted so that every row holds one (timestamp, horizon)
    pair. The result has the columns ``timestamp``, ``h``, ``prediction`` and
    ``actual``; the last two are passed together through
    ``scaler.inverse_transform`` before being returned.

    Args:
        predictions: 2-D data with ``H`` columns, one row per sample.
        test_inputs: object exposing a ``dataframe`` attribute (its index
            supplies the timestamps) and a ``"target"`` item with the actuals.
        H: the forecast horizon.
        scaler: object with an ``inverse_transform`` method.
    """
    horizon_labels = ["t+" + str(step) for step in range(1, H + 1)]
    wide = pd.DataFrame(predictions, columns=horizon_labels)
    wide["timestamp"] = test_inputs.dataframe.index

    long_df = wide.melt(id_vars="timestamp", var_name="h", value_name="prediction")

    # Melting stacks the horizon columns one after another, so the actuals are
    # transposed before flattening to follow the same horizon-major order.
    long_df["actual"] = np.transpose(test_inputs["target"]).ravel()

    value_cols = ["prediction", "actual"]
    long_df[value_cols] = scaler.inverse_transform(long_df[value_cols])
    return long_df


class TimeSeriesTensor(UserDict):
    """A dictionary of tensors for input into the RNN model.

    Use this class to:
      1. Shift the values of the time series to create a Pandas dataframe containing all the data
         for a single training example
      2. Discard any samples with missing values
      3. Transform this Pandas dataframe into a numpy array of shape
         (samples, time steps, features) for input into Keras

    The class takes the following parameters:
       - **dataset**: original time series (a Pandas dataframe whose index can be
             shifted by ``freq``)
       - **target**: name of the target column
       - **H**: the forecast horizon
       - **tensor_structure**: a dictionary describing the tensor structure of the form
             { 'tensor_name' : (range(max_backward_shift, max_forward_shift), [feature, feature, ...] ) }
             if features are non-sequential and should not be shifted, use the form
             { 'tensor_name' : (None, [feature, feature, ...])}
       - **freq**: time series frequency (default 'H' - hourly; the legacy 'H'
             spelling is translated to pandas' current 'h' alias)
       - **drop_incomplete**: (Boolean) whether to drop incomplete samples (default True)

    After construction the tensors are available by name through normal
    dictionary access (plus the key ``"target"``), and the shifted dataframe
    they were built from is available as the ``dataframe`` attribute.
    """

    def __init__(
        self, dataset, target, H, tensor_structure, freq="H", drop_incomplete=True
    ):
        super().__init__()
        self.dataset = dataset
        self.target = target
        self.tensor_structure = tensor_structure
        self.tensor_names = list(tensor_structure.keys())

        self.dataframe = self._shift_data(H, freq, drop_incomplete)
        # ``data`` is the mapping UserDict serves item access from.
        self.data = self._df2tensors(self.dataframe)

    def _shift_data(self, H, freq, drop_incomplete):
        """Return the shifted dataframe with one row per forecast creation time."""

        # Use the tensor_structure definitions to shift the features in the original dataset.
        # The result is a Pandas dataframe with multi-index columns in the hierarchy
        #     tensor - the name of the input tensor
        #     feature - the input feature to be shifted
        #     time step - the time step for the RNN in which the data is input. These labels
        #         are centred on time t. the forecast creation time

        # pandas 2.2 deprecated the uppercase hourly alias "H" in favour of "h".
        # Translate the legacy spelling (which is also this class's default) so
        # existing callers keep working without touching the signature.
        if isinstance(freq, str) and freq == "H":
            freq = "h"

        df = self.dataset.copy()

        # idx_tuples is built in the same order as the columns are appended to
        # df, so it can be installed as the column index at the end.
        idx_tuples = []
        for t in range(1, H + 1):
            df["t+" + str(t)] = df[self.target].shift(t * -1, freq=freq)
            idx_tuples.append(("target", "y", "t+" + str(t)))

        for name, structure in self.tensor_structure.items():
            rng = structure[0]
            dataset_cols = structure[1]

            for col in dataset_cols:

                # do not shift non-sequential 'static' features
                if rng is None:
                    # Key the working column on the tensor's own name. A fixed
                    # prefix would make two static tensors that share a feature
                    # overwrite one column, leaving fewer columns than labels.
                    df[name + "_" + col + "_static"] = df[col]
                    idx_tuples.append((name, col, "static"))

                else:
                    for t in rng:
                        # label is "t-2", "t-1", "t", "t+1", ... relative to time t
                        sign = "+" if t > 0 else ""
                        shift = str(t) if t != 0 else ""
                        period = "t" + sign + shift
                        shifted_col = name + "_" + col + "_" + period
                        df[shifted_col] = df[col].shift(t * -1, freq=freq)
                        idx_tuples.append((name, col, period))

        # keep only the derived columns, then replace the flat working names
        # with the (tensor, feature, time step) hierarchy
        df = df.drop(self.dataset.columns, axis=1)
        idx = pd.MultiIndex.from_tuples(
            idx_tuples, names=["tensor", "feature", "time step"]
        )
        df.columns = idx

        if drop_incomplete:
            df = df.dropna(how="any")

        return df

    def _df2tensors(self, dataframe):
        """Convert the shifted dataframe into a dict of numpy arrays keyed by tensor name."""

        # Transform the shifted Pandas dataframe into the multidimensional numpy arrays. These
        # arrays can be used to input into the keras model and can be accessed by tensor name.
        # For example, for a TimeSeriesTensor object named "model_inputs" and a tensor named
        # "target", the input tensor can be accessed with model_inputs['target']

        inputs = {}
        y = dataframe["target"]
        y = y.to_numpy()
        inputs["target"] = y

        for name, structure in self.tensor_structure.items():
            rng = structure[0]
            cols = structure[1]
            tensor = dataframe[name][cols].to_numpy()
            if rng is None:
                # static features: (samples, features)
                tensor = tensor.reshape(tensor.shape[0], len(cols))
            else:
                # columns are grouped per feature, so reshape to
                # (samples, features, time steps) and then swap the last two
                # axes to get (samples, time steps, features)
                tensor = tensor.reshape(tensor.shape[0], len(cols), len(rng))
                tensor = np.transpose(tensor, axes=[0, 2, 1])
            inputs[name] = tensor

        return inputs

    def subset_data(self, new_dataframe):
        """Replace the shifted dataframe and rebuild the tensors from it."""

        # Use this function to recreate the input tensors if the shifted dataframe
        # has been filtered.

        self.dataframe = new_dataframe
        self.data = self._df2tensors(self.dataframe)
